
#include <algorithm>
#include <chrono>
#include <cmath>
#include <functional>
#include <iostream>
#include <string>

#include "../Common/Eigen/Core"
#include "../Common/Eigen/QR"
#include "../Common/ParameterReader.hpp"
#include "../Common/TransientImage.hpp"

using Eigen::Vector2f;
using Eigen::Vector3f;
using Eigen::Vector2i;
using Eigen::Vector3i;
using Eigen::Matrix;
using Eigen::Matrix3f;

using namespace std;
using LibTransientImage::TransientImage;
using LibTransientImage::Exception;


////////////////////////////////
//
//      Helpers
//
////////////////////////////////


// Mitchell-Netravali reconstruction filter with B = C = 1/3.
// x is the offset from the filter center, normalized by the filter's half width, so the
// support of the kernel is [-1, 1]; for |x| >= 1 the result is exactly 0.
// The kernel peaks at x = 0 (value 8/9) and has small negative lobes for |x| between
// roughly 0.57 and 1, so callers that normalize by the weight sum must expect that.
float MitchellFilter(float x)
{
	const auto B = 1.f/3.f;
	const auto C = 1.f/3.f;

	// the classic kernel is defined for |x| in [0, 2); map [-1, 1] onto that range
	x = std::abs(2*x);

	// outside the support: the outer polynomial would continue (and turn negative) here
	if (x >= 2)
		return 0.f;

	if (x > 1) // outer piece: 0.5 < |x| < 1 in normalized coordinates
		return ((-B -6*C)*x*x*x +(6*B + 30*C)*x*x +
				(-12*B -48*C)*x + (8*B + 24*C))*
				(1.f/6.f);

	// inner piece: |x| <= 0.5 in normalized coordinates
	return ((12 -9*B -6*C)*x*x*x +
			(-18 + 12*B + 6*C)*x*x + (6 -2*B))*
			(1.f/6.f);
}

// Computes the planar homography (projective mapping) p -> q that maps the four points
// p0..p3 onto q0..q3, i.e. H(p_i) = q_i.
// H is a 3x3 matrix with H(2,2) fixed to 1; its 8 remaining entries are obtained by solving
// the 8x8 linear system set up from the four correspondences.
// Returns a function that applies H to a 2D point (homogeneous multiplication followed by the
// perspective division).
// NOTE(review): degenerate correspondences (e.g. three collinear points) are not detected;
// the resulting mapping is then not meaningful.
std::function<Vector2f(Vector2f)> ComputeHomographyFromCorrespondences(
	Vector2f p0, Vector2f p1, Vector2f p2, Vector2f p3,
	Vector2f q0, Vector2f q1, Vector2f q2, Vector2f q3)
{
	// With h = (H00, H01, H02, H10, H11, H12, H20, H21) and H22 = 1, the condition
	//   q.x = (H00*p.x + H01*p.y + H02) / (H20*p.x + H21*p.y + 1)
	// becomes linear in h after multiplying by the denominator; xrow is that row
	// (and yrow the analogous row for q.y).
	auto xrow = [](Vector2f p, Vector2f q) -> Matrix<float, 1, 8>
	{
		return(
			Matrix<float, 1, 8>() << p.x(), p.y(), 1, 0, 0, 0, -p.x()*q.x(), -p.y()*q.x()
			).finished();
	};

	auto yrow = [](Vector2f p, Vector2f q) -> Matrix<float, 1, 8>
	{
		return(
			Matrix<float, 1, 8>() << 0, 0, 0, p.x(), p.y(), 1, -p.x()*q.y(), -p.y()*q.y()
			).finished();
	};

	const Matrix<float, 8, 1> v = (Matrix<float, 8, 1>() << q0.x(), q0.y(), q1.x(), q1.y(), q2.x(), q2.y(), q3.x(), q3.y()).finished();
	const Matrix<float, 8, 8> M = (Matrix<float, 8, 8>() <<
		xrow(p0, q0), yrow(p0, q0),
		xrow(p1, q1), yrow(p1, q1),
		xrow(p2, q2), yrow(p2, q2),
		xrow(p3, q3), yrow(p3, q3)
		).finished();

	// Evaluate the solution into a concrete vector right away: solve() returns a lazy expression
	// that refers to the QR decomposition, and that decomposition is a temporary destroyed at the
	// end of this statement. Keeping the expression with `auto` would leave it dangling.
	const Matrix<float, 8, 1> b = M.colPivHouseholderQr().solve(v);

	const Matrix3f H = (Matrix3f() << b(0), b(1), b(2), b(3), b(4), b(5), b(6), b(7), 1).finished();

	// return a lambda that computes the transformation of a vector
	return [H](Vector2f p)
	{
		Vector3f q = H*Vector3f(p.x(), p.y(), 1.f);
		return Vector2f(q.x()/q.z(), q.y()/q.z());
	};
}

// Builds an affine function that sends p0 to q0 and p1 to q1; every other input is
// linearly interpolated (or extrapolated) along the same line.
// If p0 == p1 the returned function divides by zero and yields inf/NaN.
std::function<float(float)> Map1dInterval(float p0, float p1, float q0, float q1)
{
	auto mapping = [=](float value) -> float
	{
		// relative position of value within the source interval: 0 at p0, 1 at p1
		const float fraction = (value - p0) / (p1 - p0);
		return fraction * (q1 - q0) + q0;
	};
	return mapping;
}


////////////////////////////////
//
//      Main Program
//
////////////////////////////////

// Rounds x to the nearest integer; halfway cases are rounded away from zero.
// Unlike static_cast<int>(x + 0.5), which truncates toward zero, this is also correct for
// negative values (e.g. -1.2 -> -1).
int Round(float x)
{
	return static_cast<int>(std::lround(x));
}
// Rounds x to the nearest integer; halfway cases are rounded away from zero.
// Unlike static_cast<int>(x + 0.5), this is correct for negative values and does not suffer
// from the rounding error of the addition (0.49999999999999994 + 0.5 == 1.0 in double).
int Round(double x)
{
	return static_cast<int>(std::lround(x));
}

// Computes the inclusive index range [first, last] of all cells (pixels or time bins) whose
// centers, located at index + 0.5, lie within halfWidth of the continuous position `center`.
// The range is clipped to the valid indices [0, resolution-1].
// Returns false, leaving first/last untouched, in two cases: the range is empty, or the
// position is unusable (non-finite center, e.g. from a degenerate homography, or non-positive
// halfWidth).
static bool FilterSupport(float center, float halfWidth, int resolution, int& first, int& last)
{
	if(!std::isfinite(center) || !(halfWidth > 0.f) || resolution <= 0)
		return false;

	// cell i lies inside the support iff |i + 0.5 - center| <= halfWidth
	const float lo = std::ceil(center - halfWidth - 0.5f);
	const float hi = std::floor(center + halfWidth - 0.5f);
	const float maxIndex = static_cast<float>(resolution - 1);
	if(hi < 0.f || lo > maxIndex || lo > hi)
		return false;

	// clamp while still in float, so the conversion to int can never overflow
	first = static_cast<int>(std::max(lo, 0.f));
	last  = static_cast<int>(std::min(hi, maxIndex));
	return true;
}

// Resamples inputImage onto a new transient image.
//  - Spatially, the output has uResOut x vResOut pixels and covers the wall quadrilateral
//    q0 (top left), q1 (top right), q2 (bottom right), q3 (bottom left). The corners are
//    given as (x, z) coordinates in the Y=0 plane.
//  - Temporally, the output has tResOut bins and covers [tMinOut, tMaxOut].
// Each output sample is a weight-normalized, separable Mitchell-filtered average of the input
// samples around the corresponding input position. spatialFilterWidth is measured in input
// pixels and temporalFilterWidth in input bins. Negative results (filter ringing) are clamped
// to 0.
// Throws Exception if a filter width is not positive or if the input image does not lie in the
// Y=0 plane. Prints progress to stdout.
// todo: add cameraPosition
TransientImage ResampleImage(const TransientImage& inputImage,
	unsigned int tResOut, unsigned int uResOut, unsigned int vResOut,
	float tMinOut, float tMaxOut,
	Vector2f q0, Vector2f q1, Vector2f q2, Vector2f q3,
	float spatialFilterWidth, float temporalFilterWidth)
{
	// todo: consider transformation of input image

	// the filter arguments are divided by the half widths, so these must be positive
	// (the negated comparison also rejects NaN)
	if(!(spatialFilterWidth > 0.f) || !(temporalFilterWidth > 0.f))
		throw Exception("Filter widths must be positive");

	// we assume, all points lie on the Y=0 plane
	if(    inputImage.pixelInterpretationBlock.topLeft[1] != 0
		|| inputImage.pixelInterpretationBlock.topRight[1] != 0
		|| inputImage.pixelInterpretationBlock.bottomRight[1] != 0
		|| inputImage.pixelInterpretationBlock.bottomLeft[1] != 0)
		throw Exception("Input image is not in the Y=0 plane");

	const auto halfSpatialFilterWidth = spatialFilterWidth / 2.f;
	const auto halfTemporalFilterWidth = temporalFilterWidth / 2.f;

	const auto tResIn = inputImage.header.numBins;
	const auto uResIn = inputImage.pixelInterpretationBlock.uResolution;
	const auto vResIn = inputImage.pixelInterpretationBlock.vResolution;

	const auto tMinIn = inputImage.header.tMin;
	const auto tMaxIn = inputImage.header.tMin + inputImage.header.numBins*inputImage.header.tDelta;

	TransientImage output;

	output.header.pixelMode = 10;
	output.header.numPixels = uResOut*vResOut;
	output.header.numBins = tResOut;
	output.header.tMin = tMinOut;
	output.header.tDelta = (tMaxOut-tMinOut) / tResOut;
	output.header.pixelInterpretationBlockSize = sizeof(output.pixelInterpretationBlock);

	output.data.resize(tResOut*uResOut*vResOut);

	output.pixelInterpretationBlock.uResolution = uResOut;
	output.pixelInterpretationBlock.vResolution = vResOut;
	output.pixelInterpretationBlock.topLeft    = {q0.x(), 0, q0.y()};
	output.pixelInterpretationBlock.topRight   = {q1.x(), 0, q1.y()};
	output.pixelInterpretationBlock.bottomRight= {q2.x(), 0, q2.y()};
	output.pixelInterpretationBlock.bottomLeft = {q3.x(), 0, q3.y()};
	output.pixelInterpretationBlock.laserPosition = {0, 0, 0};

	output.imageProperties = inputImage.imageProperties; // copy meta info

	/*
	There are a total of three transformations:
	 - output pixels to output wall positions
	 - output wall positions to input wall positions
	 - input wall positions to input pixels
	*/
	auto OutputWallFromOutputPixel = ComputeHomographyFromCorrespondences(
	{0, 0},
	{output.pixelInterpretationBlock.uResolution, 0},
	{output.pixelInterpretationBlock.uResolution, output.pixelInterpretationBlock.vResolution},
	{0, output.pixelInterpretationBlock.vResolution},

	{output.pixelInterpretationBlock.topLeft[0],     output.pixelInterpretationBlock.topLeft[2]},
	{output.pixelInterpretationBlock.topRight[0],    output.pixelInterpretationBlock.topRight[2]},
	{output.pixelInterpretationBlock.bottomRight[0], output.pixelInterpretationBlock.bottomRight[2]},
	{output.pixelInterpretationBlock.bottomLeft[0],  output.pixelInterpretationBlock.bottomLeft[2]});

	auto InputWallFromOutputWall = ComputeHomographyFromCorrespondences(
	{output.pixelInterpretationBlock.topLeft[0],     output.pixelInterpretationBlock.topLeft[2]},
	{output.pixelInterpretationBlock.topRight[0],    output.pixelInterpretationBlock.topRight[2]},
	{output.pixelInterpretationBlock.bottomRight[0], output.pixelInterpretationBlock.bottomRight[2]},
	{output.pixelInterpretationBlock.bottomLeft[0],  output.pixelInterpretationBlock.bottomLeft[2]},

	{inputImage.pixelInterpretationBlock.topLeft[0],     inputImage.pixelInterpretationBlock.topLeft[2]},
	{inputImage.pixelInterpretationBlock.topRight[0],    inputImage.pixelInterpretationBlock.topRight[2]},
	{inputImage.pixelInterpretationBlock.bottomRight[0], inputImage.pixelInterpretationBlock.bottomRight[2]},
	{inputImage.pixelInterpretationBlock.bottomLeft[0],  inputImage.pixelInterpretationBlock.bottomLeft[2]});

	auto InputPixelFromInputWall = ComputeHomographyFromCorrespondences(
	{inputImage.pixelInterpretationBlock.topLeft[0],     inputImage.pixelInterpretationBlock.topLeft[2]},
	{inputImage.pixelInterpretationBlock.topRight[0],    inputImage.pixelInterpretationBlock.topRight[2]},
	{inputImage.pixelInterpretationBlock.bottomRight[0], inputImage.pixelInterpretationBlock.bottomRight[2]},
	{inputImage.pixelInterpretationBlock.bottomLeft[0],  inputImage.pixelInterpretationBlock.bottomLeft[2]},

	{0, 0},
	{inputImage.pixelInterpretationBlock.uResolution, 0},
	{inputImage.pixelInterpretationBlock.uResolution, inputImage.pixelInterpretationBlock.vResolution},
	{0, inputImage.pixelInterpretationBlock.vResolution});

	// time axis: continuous output bin position -> time -> continuous input bin position
	// (on both sides, bin k spans the positions [k, k+1) and has its center at k + 0.5)
	auto ToDistance = Map1dInterval(0.f, static_cast<float>(tResOut), static_cast<float>(tMinOut), static_cast<float>(tMaxOut));
	auto FromDistance = Map1dInterval(static_cast<float>(tMinIn), static_cast<float>(tMaxIn), 0.f, static_cast<float>(tResIn));

	const int tResInInt = static_cast<int>(tResIn);
	const int uResInInt = static_cast<int>(uResIn);
	const int vResInInt = static_cast<int>(vResIn);

	const auto tStart = chrono::steady_clock::now();
	auto lastProgressTime = tStart;

	// loop over the output image and generate the value of every output sample
	for(int vOut = 0; vOut < static_cast<int>(vResOut); ++vOut)
	{
		for(int uOut = 0; uOut < static_cast<int>(uResOut); ++uOut)
		{
			// from indices to positions, the half pixel offset must be added!
			const Vector2f pixelPosOut(uOut + 0.5f, vOut + 0.5f);
			const Vector2f pixelPosIn = InputPixelFromInputWall(InputWallFromOutputWall(OutputWallFromOutputPixel(pixelPosOut)));

			// TODO: temporal offset, to be computed from camera and laser position

			// input pixels inside the spatial filter support (identical for all time bins)
			int uFirst = 0, uLast = -1, vFirst = 0, vLast = -1;
			const bool spatialHit =
				   FilterSupport(pixelPosIn.x(), halfSpatialFilterWidth, uResInInt, uFirst, uLast)
				&& FilterSupport(pixelPosIn.y(), halfSpatialFilterWidth, vResInInt, vFirst, vLast);

			for(int tOut = 0; tOut < static_cast<int>(tResOut); ++tOut)
			{
				// continuous input position of the CENTER of output bin tOut (half bin offset,
				// exactly as for the pixels above)
				const float tIn = FromDistance(ToDistance(tOut + 0.5f));

				float value = 0.f;
				float weight = 0.f;
				int binFirst = 0, binLast = -1;
				if(spatialHit && FilterSupport(tIn, halfTemporalFilterWidth, tResInInt, binFirst, binLast))
				{
					// separable filter: the weight of an input sample is the product of the
					// 1D weights of its distances from the sampling position along u, v and t
					for(int uIn = uFirst; uIn <= uLast; ++uIn)
					{
						const float uWeight = MitchellFilter((uIn + 0.5f - pixelPosIn.x()) / halfSpatialFilterWidth);
						for(int vIn = vFirst; vIn <= vLast; ++vIn)
						{
							const float uvWeight = uWeight * MitchellFilter((vIn + 0.5f - pixelPosIn.y()) / halfSpatialFilterWidth);
							for(int tBin = binFirst; tBin <= binLast; ++tBin)
							{
								const float w = uvWeight * MitchellFilter((tBin + 0.5f - tIn) / halfTemporalFilterWidth);
								value += inputImage.AccessPixel(tBin, uIn, vIn) * w;
								weight += w;
							}
						}
					}
				}

				// normalize by the accumulated weight; negative results (filter ringing) are clamped to 0
				output.AccessPixel(tOut, uOut, vOut) = (0.f == weight) ? 0.f : std::max(value / weight, 0.f);
			}
		}

		const auto now = chrono::steady_clock::now();
		if(chrono::duration_cast<chrono::milliseconds>(now - lastProgressTime).count() > 500)
		{
			cout << "\r" << vOut*100/vResOut << "%    ";
			lastProgressTime = now;
		}
	}
	cout << "\rdone. Total time: " << chrono::duration_cast<chrono::seconds>(chrono::steady_clock::now()-tStart).count() << "s" << endl;

	return output;
}


int main(int argc, char* argv[])
{
	std::string inputFilename;
	const std::string usage = "SetupConvert -i \"Input.ti\" -res 360 128 128 -t 0 4 -q0 150 400 -q1 450 420 -q2 450 120 -q3 150 250 -o \"Output.ti\" -sf 5 -tf 2";
	try
	{

		// we first read the input file to initialize parameters that can later be replaced by command line parameters
		for(auto& p : ParameterReader(argc, argv))
		{
			if("-i" == p.Name())
			{
				inputFilename = p.ReadString();
				if(inputFilename.find(".ti") == string::npos)
					throw Exception(inputFilename+" is not a *.ti file");
			}
		}
		if(inputFilename.empty())
		{
			throw Exception("no input filename given. Example usage: "+usage);
		}
		
		TransientImage input(inputFilename);

		// input parameters (copy from input file)
		auto tResOut = input.header.numBins;
		auto uResOut = input.pixelInterpretationBlock.uResolution;
		auto vResOut = input.pixelInterpretationBlock.vResolution;
		auto tMinOut = input.header.tMin;
		auto tMaxOut = input.header.tMin+input.header.tDelta*input.header.numBins;
		Vector2f q0(input.pixelInterpretationBlock.topLeft[0], input.pixelInterpretationBlock.topLeft[2]);
		Vector2f q1(input.pixelInterpretationBlock.topRight[0], input.pixelInterpretationBlock.topRight[2]);
		Vector2f q2(input.pixelInterpretationBlock.bottomRight[0], input.pixelInterpretationBlock.bottomRight[2]);
		Vector2f q3(input.pixelInterpretationBlock.bottomLeft[0], input.pixelInterpretationBlock.bottomLeft[2]);

		std::string outputFilename("Converted.ti");
		float temporalFilterWidth = 2.f;
		float spatialFilterWidth = 5.f;

		// if the filters are not specified, we compute reasonable parameters ourself
		bool spatialFilterSet = false;
		bool temporalFilterSet = false;

		// read remaining parameters
		for(auto& p : ParameterReader(argc, argv))
		{
			if("-i" == p.Name())
			{
				// ignore this parameter to not trigger the exception
				p.ReadString();
			}
			else if("-res" == p.Name())
			{
				tResOut = p.ReadInt();
				uResOut = p.ReadInt();
				vResOut = p.ReadInt();
			}
			else if("-t" == p.Name())
			{
				tMinOut = p.ReadFloat();
				tMaxOut = p.ReadFloat();
			}
			else if("-q0" == p.Name())
			{
				q0[0] = p.ReadFloat();
				q0[1] = p.ReadFloat();
			}
			else if("-q1" == p.Name())
			{
				q1[0] = p.ReadFloat();
				q1[1] = p.ReadFloat();
			}
			else if("-q2" == p.Name())
			{
				q2[0] = p.ReadFloat();
				q2[1] = p.ReadFloat();
			}
			else if("-q3" == p.Name())
			{
				q3[0] = p.ReadFloat();
				q3[1] = p.ReadFloat();
			}
			else if("-o" == p.Name())
			{
				outputFilename = p.ReadString();
			}
			else if("-sf" == p.Name())
			{
				spatialFilterWidth = p.ReadFloat();
				spatialFilterSet = true;
			}
			else if("-tf" == p.Name())
			{
				temporalFilterWidth = p.ReadFloat();
				temporalFilterSet = true;
			}
			else
			{
				throw Exception("Unknown Parameter '"+p.Name()+"'. Usage Example: "+usage);
			}
		}

		if(!spatialFilterSet)
		{
			// see FilterSize.lyx
			// here we ignore the y resolution, because we have only a single filter size.
			spatialFilterWidth = max(4.f, 4.f*static_cast<float>(input.pixelInterpretationBlock.uResolution)/static_cast<float>(uResOut));
			cout << "new spatial filter size computed: " << spatialFilterWidth << endl;
		}
		if(!temporalFilterSet)
		{
			// see FilterSize.lyx
			// here we ignore the y resolution, because we have only a single filter size.
			temporalFilterWidth = max(4.f, 4.f*static_cast<float>(input.header.numBins)/static_cast<float>(tResOut));
			cout << "new temporal filter size computed: " << temporalFilterWidth << endl;
		}
	


		cout << "Image read, starting conversion" << endl;
		auto output = ResampleImage(inputFilename,
			tResOut, uResOut, vResOut,
			tMinOut, tMaxOut,
			q0, q1, q2, q3, spatialFilterWidth, temporalFilterWidth);
		output.WriteFile(outputFilename);
		cout << "\nfile written." << endl;
	}
	catch(std::exception &ex)
	{
		std::cout << "Error: " << inputFilename << ": " << ex.what() << std::endl;
	}
	return 0;
}
